/*
 * Copyright (c) 2017 NXP
 * All rights reserved.
 *
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <stdio.h>
#include <string.h>

#include "fsl_common.h"
#include "board.h"
#include "pin_mux.h"

#include "fc_layer.h"
#include "arm_math.h"

/*******************************************************************************
 * Definitions
 ******************************************************************************/

/*******************************************************************************
 * Variables
 ******************************************************************************/
/* 输入样本, 在跟随系统里, 同时也是输出特征. */
#define APP_FC_DIM_IN  30u
#define APP_FC_SAMPLES 200u
extern float app_fc_layer_input_data[APP_FC_SAMPLES * APP_FC_DIM_IN];

/* 构建一个4层网络. 前级的输出直接作为后级的输入.*/
#define APP_FC_LAYER0_DIM_IN   APP_FC_DIM_IN
#define APP_FC_LAYER0_DIM_OUT  16u
#define APP_FC_LAYER1_DIM_IN   APP_FC_LAYER0_DIM_OUT
#define APP_FC_LAYER1_DIM_OUT  8u
#define APP_FC_LAYER2_DIM_IN   APP_FC_LAYER1_DIM_OUT
#define APP_FC_LAYER2_DIM_OUT  8u
#define APP_FC_LAYER3_DIM_IN   APP_FC_LAYER2_DIM_OUT
#define APP_FC_LAYER3_DIM_OUT  APP_FC_DIM_IN

#define APP_FC_DIM_OUT  APP_FC_LAYER3_DIM_OUT

float app_fc_layer0_output[APP_FC_LAYER0_DIM_OUT];
float app_fc_layer1_output[APP_FC_LAYER1_DIM_OUT];
float app_fc_layer2_output[APP_FC_LAYER2_DIM_OUT];
float app_fc_layer3_output[APP_FC_LAYER3_DIM_OUT];
float app_fc_output[APP_FC_DIM_OUT]; /* 将在fc_infer()函数执行之后被赋予app_fc_layer3_output. 算了, 直接复制出一份吧. */

/* 网络本身的神经元参数. */
/* 本级权重. */
float app_fc_layer0_W[APP_FC_LAYER0_DIM_OUT * APP_FC_LAYER0_DIM_IN];
float app_fc_layer1_W[APP_FC_LAYER1_DIM_OUT * APP_FC_LAYER1_DIM_IN];
float app_fc_layer2_W[APP_FC_LAYER2_DIM_OUT * APP_FC_LAYER2_DIM_IN];
float app_fc_layer3_W[APP_FC_LAYER3_DIM_OUT * APP_FC_LAYER3_DIM_IN];

/* 本级偏置. */
float app_fc_layer0_B[APP_FC_LAYER0_DIM_OUT];
float app_fc_layer1_B[APP_FC_LAYER1_DIM_OUT];
float app_fc_layer2_B[APP_FC_LAYER2_DIM_OUT];
float app_fc_layer3_B[APP_FC_LAYER3_DIM_OUT];

/* 为训练网络开辟的缓存区. */
/* 反馈给本级权重. */
float app_fc_layer0_dW[APP_FC_LAYER0_DIM_OUT * APP_FC_LAYER0_DIM_IN];
float app_fc_layer1_dW[APP_FC_LAYER1_DIM_OUT * APP_FC_LAYER1_DIM_IN];
float app_fc_layer2_dW[APP_FC_LAYER2_DIM_OUT * APP_FC_LAYER2_DIM_IN];
float app_fc_layer3_dW[APP_FC_LAYER3_DIM_OUT * APP_FC_LAYER3_DIM_IN];

/* 反馈给本级偏置. */
float app_fc_layer0_dB[APP_FC_LAYER0_DIM_OUT];
float app_fc_layer1_dB[APP_FC_LAYER1_DIM_OUT];
float app_fc_layer2_dB[APP_FC_LAYER2_DIM_OUT];
float app_fc_layer3_dB[APP_FC_LAYER3_DIM_OUT];

/* 反馈给前级别的输出. */
float app_fc_layer0_dO[APP_FC_LAYER0_DIM_OUT];
float app_fc_layer1_dO[APP_FC_LAYER1_DIM_OUT];
float app_fc_layer2_dO[APP_FC_LAYER2_DIM_OUT];
//float app_fc_layer3_dO[APP_FC_LAYER3_DIM_OUT]; /* app_fc_dO[APP_FC_DIM_OUT] */
float app_fc_output_error[APP_FC_DIM_OUT];

/* 创建网络各层的实例. */
fc_layer_t app_fc_layer0;
fc_layer_t app_fc_layer1;
fc_layer_t app_fc_layer2;
fc_layer_t app_fc_layer3;

/*******************************************************************************
 * Prototypes
 ******************************************************************************/

void fc_init(void);
void fc_infer(float *input, float *output);
void fc_backprop(float *input, float *output_error);
void fc_update(void);
float fc_get_loss(float *error);
void fc_print(void);

/*******************************************************************************
 * Code
 ******************************************************************************/

void my_delay(uint32_t t)
{
    for (uint32_t i = 0u; i < t; i++)
    {
        for (uint32_t j = 0u; j < 50000u; j++)
        {
            asm("nop");
        }
    }
}


/*!
 * @brief Main function
 */
int main(void)
{
    board_init();
    printf("nn_fc_self_training example.\r\n");
    printf("build on %s, %s\r\n", __DATE__, __TIME__);

    fc_init();

    /* 训练5000次. 样本库里有200个样本, 也就是让模型将整个样本库念25遍. */
    for (uint32_t epoch = 0u; epoch < 25u; epoch++)
    {
        float *app_fc_input = app_fc_layer_input_data;
        float app_fc_loss_sum = 0.0f;

        /* 遍历一次样本库. */
        for (uint32_t index = 0u; index < APP_FC_SAMPLES; index++)
        {
            /* 先试着算一下. */
            fc_infer(app_fc_input, app_fc_output);

            /* 计算偏差. */
            arm_sub_f32(app_fc_input, app_fc_output, app_fc_output_error, APP_FC_DIM_OUT);
            app_fc_loss_sum += fc_get_loss(app_fc_output_error);

            /* 从错误中学习. */
            fc_backprop(app_fc_input, app_fc_output_error); /* input而产生了output_error的偏差, 网络要反思一下. */

            /* 积累到一定程度进行调整. */
            if ((index % 20) == 0u)
            {
                fc_update();
            }

            /* 准备学习下一个样本. */
            app_fc_input += APP_FC_DIM_IN;
        }

        printf("[%3d]loss: %f\r\n", epoch, app_fc_loss_sum);
    }

    printf("training done.\r\n");
    fc_print();

    while (1)
    {
    }
}

#define APP_FC_RAND_MIN  (-0.5f)
#define APP_FC_RAND_MAX  ( 0.5f)

void fc_init(void)
{
    /* layer0. */
    arm_rand_f32(app_fc_layer0_W, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER0_DIM_OUT * APP_FC_LAYER0_DIM_IN);
    arm_rand_f32(app_fc_layer0_B, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER0_DIM_OUT);
    fc_layer_new_infer(&app_fc_layer0, APP_FC_LAYER0_DIM_IN, APP_FC_LAYER0_DIM_OUT, app_fc_layer0_W, app_fc_layer0_B);
    fc_layer_new_backprop(&app_fc_layer0, APP_FC_LAYER0_DIM_IN, APP_FC_LAYER0_DIM_OUT, app_fc_layer0_dW, app_fc_layer0_dB);
    fc_layer_clear_gradients(&app_fc_layer0);

    /* layer1. */
    arm_rand_f32(app_fc_layer1_W, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER1_DIM_OUT * APP_FC_LAYER1_DIM_IN);
    arm_rand_f32(app_fc_layer1_B, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER1_DIM_OUT);
    fc_layer_new_infer(&app_fc_layer1, APP_FC_LAYER1_DIM_IN, APP_FC_LAYER1_DIM_OUT, app_fc_layer1_W, app_fc_layer1_B);
    fc_layer_new_backprop(&app_fc_layer1, APP_FC_LAYER1_DIM_IN, APP_FC_LAYER1_DIM_OUT, app_fc_layer1_dW, app_fc_layer1_dB);
    fc_layer_clear_gradients(&app_fc_layer1);

    /* layer2. */
    arm_rand_f32(app_fc_layer2_W, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER2_DIM_OUT * APP_FC_LAYER2_DIM_IN);
    arm_rand_f32(app_fc_layer2_B, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER2_DIM_OUT);
    fc_layer_new_infer(&app_fc_layer2, APP_FC_LAYER2_DIM_IN, APP_FC_LAYER2_DIM_OUT, app_fc_layer2_W, app_fc_layer2_B);
    fc_layer_new_backprop(&app_fc_layer2, APP_FC_LAYER2_DIM_IN, APP_FC_LAYER2_DIM_OUT, app_fc_layer2_dW, app_fc_layer2_dB);
    fc_layer_clear_gradients(&app_fc_layer2);

    /* layer3. */
    arm_rand_f32(app_fc_layer3_W, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER3_DIM_OUT * APP_FC_LAYER3_DIM_IN);
    arm_rand_f32(app_fc_layer3_B, APP_FC_RAND_MIN, APP_FC_RAND_MAX, APP_FC_LAYER3_DIM_OUT);
    fc_layer_new_infer(&app_fc_layer3, APP_FC_LAYER3_DIM_IN, APP_FC_LAYER3_DIM_OUT, app_fc_layer3_W, app_fc_layer3_B);
    fc_layer_new_backprop(&app_fc_layer3, APP_FC_LAYER3_DIM_IN, APP_FC_LAYER3_DIM_OUT, app_fc_layer3_dW, app_fc_layer3_dB);
    fc_layer_clear_gradients(&app_fc_layer3);
}

void fc_infer(float *input, float *output)
{
    fc_layer_infer(&app_fc_layer0, input, app_fc_layer0_output);
    fc_layer_infer(&app_fc_layer1, app_fc_layer0_output, app_fc_layer1_output);
    fc_layer_infer(&app_fc_layer2, app_fc_layer1_output, app_fc_layer2_output);
    fc_layer_infer(&app_fc_layer3, app_fc_layer2_output, app_fc_layer3_output);

    //output = app_fc_layer3_output; /* 指针传递. 注意app_fc_layer3_output的值在反推的时候还会用到, 所以不要直接修改 */
    for (uint32_t i = 0u; i < app_fc_layer3.dim_out; i++)
    {
        output[i] = app_fc_layer3_output[i];
    }
}

void fc_backprop(float *input, float *output_error)
{
    //float *app_fc_layer3_dO = output_error;
    fc_layer_backprop(&app_fc_layer3, app_fc_layer3_output, output_error, app_fc_layer2_output, app_fc_layer2_dO);
    fc_layer_backprop(&app_fc_layer2, app_fc_layer2_output, app_fc_layer2_dO, app_fc_layer1_output, app_fc_layer1_dO);
    fc_layer_backprop(&app_fc_layer1, app_fc_layer1_output, app_fc_layer1_dO, app_fc_layer0_output, app_fc_layer0_dO);
    fc_layer_backprop(&app_fc_layer0, app_fc_layer0_output, app_fc_layer0_dO, input, NULL);
}

void fc_update(void)
{
    fc_layer_update(&app_fc_layer0, 0.005f);
    fc_layer_update(&app_fc_layer1, 0.005f);
    fc_layer_update(&app_fc_layer2, 0.005f);
    fc_layer_update(&app_fc_layer3, 0.005f);
}

float fc_get_loss(float *error)
{
    float loss_sum = 0.0f;

    for (uint32_t i = 0u; i < APP_FC_DIM_OUT; i++)
    {
        loss_sum += ( (*error) * (*error) );
        error++;
    }

    return loss_sum;
}

void fc_layer_print(fc_layer_t *layer, char *name)
{
    uint32_t i_out, i_in;

    printf("%s: dim_in = %d, dim_out = %d\r\n", name, layer->dim_in, layer->dim_out);

    /* 权重. */
    printf("W =\r\n[\r\n");
    for (i_out = 0u; i_out < layer->dim_out; i_out++)
    {
        printf("    [ ");
        for (i_in = 0u; i_in < layer->dim_in-1u; i_in++)
        {
            printf("%6.3f,", layer->W[i_out * layer->dim_in + i_in]);
        }
        printf("%6.3f ]\r\n", layer->W[i_out * layer->dim_in + i_in]);
    }
    printf("]\r\n");

    /* 偏置. */
    printf("B =\r\n[");
    for (i_out = 0u; i_out < layer->dim_out-1u; i_out++)
    {
        printf("%6.3f,", layer->B[i_out]);
    }
    printf("%6.3f ]\r\n", layer->B[i_out]);
}

void fc_print(void)
{
    fc_layer_print(&app_fc_layer0, "app_fc_layer0"); printf("\r\n");
    fc_layer_print(&app_fc_layer1, "app_fc_layer1"); printf("\r\n");
    fc_layer_print(&app_fc_layer2, "app_fc_layer2"); printf("\r\n");
    fc_layer_print(&app_fc_layer3, "app_fc_layer3"); printf("\r\n");
}

/* EOF. */

